import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Fix the RNG so the randomly-initialised weight below is reproducible.
torch.manual_seed(1)

LR = 0.1          # base learning rate handed to the optimizer
iteration = 10    # inner optimisation steps per epoch
max_epoch = 200   # epochs each scheduler demo runs for

# ------------------------------ fake data and optimizer  ------------------------------
# A single trainable scalar pushed towards a zero target: just enough work
# to let the optimizer (and hence each LR scheduler) run end to end.
weights = torch.randn(1, requires_grad=True)
target = torch.zeros(1)

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ============================ 1 Step LR ===============================
flag = 0
# flag = 1
if flag:
    # Multiply the LR by gamma=0.1 every 50 epochs.
    scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        # Record the current LR. get_last_lr() is the supported accessor;
        # the older get_lr() is internal, emits a UserWarning, and can
        # return warped values when called outside scheduler.step().
        lr_list.append(scheduler_lr.get_last_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights-target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
        # Scheduler advances once per epoch, after the inner training loop.
        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

# ===================================== 2 Multi Step LR ==========================================
flag = 0
# flag = 1
if flag:
    # Multiply the LR by gamma=0.1 at each listed epoch milestone.
    milestones = [50, 125, 160]
    scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        # get_last_lr() replaces the deprecated get_lr(), which warns (and
        # can return warped values) when called outside scheduler.step().
        lr_list.append(scheduler_lr.get_last_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

# ===================================== 3 Exponential LR ==========================================
flag = 0
# flag = 1
if flag:
    # Multiply the LR by gamma every epoch: lr = base_lr * gamma**epoch.
    gamma = 0.95
    scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        # get_last_lr() replaces the deprecated get_lr(), which warns (and
        # can return warped values) when called outside scheduler.step().
        lr_list.append(scheduler_lr.get_last_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

# ===================================== 4 Cosine Annealing LR ==========================================
flag = 0
# flag = 1
if flag:
    # Cosine-shaped decay from base_lr down to eta_min over T_max epochs,
    # then back up — one full period every 2*T_max epochs.
    t_max = 50
    scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        # get_last_lr() replaces the deprecated get_lr(), which warns (and
        # can return warped values) when called outside scheduler.step().
        lr_list.append(scheduler_lr.get_last_lr())
        epoch_list.append(epoch)

        for i in range(iteration):
            loss = torch.pow((weights - target), 2)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()

    plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.legend()
    plt.show()

# ===================================== 5 Reduce LR On plateau ==========================================
flag = 0
# flag = 1
if flag:
    loss_value = 0.5
    accuracy = 0.9

    # Shrink the LR by `factor` once the monitored metric has stopped
    # improving for `patience` epochs, then pause monitoring for
    # `cooldown` epochs before watching again.
    scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.1,
        mode='min',      # the monitored quantity is a loss: lower is better
        patience=10,
        cooldown=10,
        min_lr=1e-4,     # never decay below this floor
        verbose=True,    # print a message each time the LR is reduced
    )

    for epoch in range(max_epoch):
        for i in range(iteration):
            optimizer.step()
            optimizer.zero_grad()

        # Simulate one improvement at epoch 5, then a plateau forever after.
        if epoch == 5:
            loss_value = 0.4

        # Unlike the other schedulers, step() here takes the metric to monitor.
        scheduler_lr.step(loss_value)

# ===================================== 6 lambda ==========================================
# flag = 0
flag = 1
if flag:
    lr_init = 0.1

    # Two independent parameter groups, so each can follow its own LR policy.
    weight_1 = torch.randn((6,3,5,5))
    weight_2 = torch.ones((5,5))

    optimizer = optim.SGD([{'params': [weight_1]}, {'params': [weight_2]}], lr=lr_init)

    # Each lambda maps the epoch index to a multiplicative factor applied to
    # that group's base LR: lr = base_lr * lambda(epoch).
    lambda1 = lambda epoch: 0.1**(epoch//20)   # step decay: x0.1 every 20 epochs
    lambda2 = lambda epoch: 0.95**epoch        # smooth exponential decay
    # LambdaLR is the most flexible scheduler: passing one lambda per
    # parameter group gives every group its own schedule, which is very
    # handy when fitting models with heterogeneous parameter groups.

    scheduler_lr = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

    lr_list, epoch_list = list(), list()
    for epoch in range(max_epoch):
        for i in range(iteration):
            optimizer.step()
            optimizer.zero_grad()
        scheduler_lr.step()

        # get_last_lr() returns one LR per parameter group. The deprecated
        # get_lr() is for internal use and emits a UserWarning (and can
        # return warped values) when called directly like this.
        lr_list.append(scheduler_lr.get_last_lr())
        epoch_list.append(epoch)
        print('epoch:{:5d}, lr:{}'.format(epoch, scheduler_lr.get_last_lr()))

    plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")
    plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")
    plt.xlabel("Epoch")
    plt.ylabel("Learning rate")
    plt.title('LambdaLR')
    plt.legend()
    plt.show()